{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Skip-gram with naive softmax"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I recommend you take a look at these materials first."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf\n",
    "* https://arxiv.org/abs/1301.3781\n",
    "* http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import nltk\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "def flatten(l):\n",
    "    \"\"\"Concatenate the items of every sub-iterable of ``l`` into one flat list.\"\"\"\n",
    "    flat = []\n",
    "    for sublist in l:\n",
    "        flat.extend(sublist)\n",
    "    return flat\n",
    "\n",
    "random.seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3.0.post4\n",
      "3.2.4\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "print(nltk.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "USE_CUDA = torch.cuda.is_available()\n",
    "gpus = [0]\n",
    "# Selecting a CUDA device needs a usable CUDA runtime, so only do it when one\n",
    "# was detected; otherwise the CPU tensor types below are used.\n",
    "if USE_CUDA:\n",
    "    torch.cuda.set_device(gpus[0])\n",
    "\n",
    "# Tensor constructors for the device chosen above.\n",
    "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getBatch(batch_size, train_data):\n",
    "    \"\"\"Shuffle ``train_data`` in place and yield it in consecutive mini-batches.\n",
    "\n",
    "    Args:\n",
    "        batch_size: maximum number of items per batch; must be positive.\n",
    "        train_data: mutable sequence (e.g. a list) of training examples. It is\n",
    "            shuffled in place with the module-level ``random`` generator.\n",
    "\n",
    "    Yields:\n",
    "        Slices of ``train_data`` holding ``batch_size`` items each; the final\n",
    "        batch may be shorter. Nothing is yielded when ``train_data`` is empty.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``batch_size`` is not positive. Because this is a\n",
    "            generator, the error surfaces when iteration starts.\n",
    "    \"\"\"\n",
    "    if batch_size <= 0:\n",
    "        # A non-positive step would never advance through the data.\n",
    "        raise ValueError(\"batch_size must be a positive integer\")\n",
    "    random.shuffle(train_data)\n",
    "    for sindex in range(0, len(train_data), batch_size):\n",
    "        yield train_data[sindex: sindex + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def prepare_sequence(seq, word2index):\n",
    "    \"\"\"Convert the tokens of ``seq`` into a ``Variable`` of vocabulary indices.\n",
    "\n",
    "    Each token is looked up in ``word2index``; a token whose lookup yields\n",
    "    ``None`` (i.e. it is not in the mapping) uses the index stored under\n",
    "    ``\"<UNK>\"`` instead. The indices are packed into a ``LongTensor``.\n",
    "    \"\"\"\n",
    "    indices = []\n",
    "    for token in seq:\n",
    "        index = word2index.get(token)\n",
    "        if index is None:\n",
    "            index = word2index[\"<UNK>\"]\n",
    "        indices.append(index)\n",
    "    return Variable(LongTensor(indices))\n",
    "\n",
    "def prepare_word(word, word2index):\n",
    "    \"\"\"Convert a single ``word`` into a one-element ``Variable`` holding its index.\n",
    "\n",
    "    A word whose lookup in ``word2index`` yields ``None`` (i.e. it is not in\n",
    "    the mapping) uses the index stored under ``\"<UNK>\"`` instead.\n",
    "    \"\"\"\n",
    "    index = word2index.get(word)\n",
    "    if index is None:\n",
    "        index = word2index[\"<UNK>\"]\n",
    "    return Variable(LongTensor([index]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data load and Preprocessing "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load corpus : Gutenberg corpus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you don't have the Gutenberg corpus, download it first with `nltk.download('gutenberg')`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['austen-emma.txt',\n",
       " 'austen-persuasion.txt',\n",
       " 'austen-sense.txt',\n",
       " 'bible-kjv.txt',\n",
       " 'blake-poems.txt',\n",
       " 'bryant-stories.txt',\n",
       " 'burgess-busterbrown.txt',\n",
       " 'carroll-alice.txt',\n",
       " 'chesterton-ball.txt',\n",
       " 'chesterton-brown.txt',\n",
       " 'chesterton-thursday.txt',\n",
       " 'edgeworth-parents.txt',\n",
       " 'melville-moby_dick.txt',\n",
       " 'milton-paradise.txt',\n",
       " 'shakespeare-caesar.txt',\n",
       " 'shakespeare-hamlet.txt',\n",
       " 'shakespeare-macbeth.txt',\n",
       " 'whitman-leaves.txt']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nltk.corpus.gutenberg.fileids()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sample the first 100 sentences of Moby Dick for a quick test run and\n",
    "# lower-case every token.\n",
    "corpus = [\n",
    "    [word.lower() for word in sent]\n",
    "    for sent in list(nltk.corpus.gutenberg.sents('melville-moby_dick.txt'))[:100]\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Stopwords from unigram distribution's tails"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Unigram counts over every token of the corpus.\n",
    "word_count = Counter(token for sentence in corpus for token in sentence)\n",
    "# Number of distinct words that make up 1% of the vocabulary (rounded down).\n",
    "border = int(0.01 * len(word_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Both tails of the unigram distribution: the `border` most frequent and the\n",
    "# `border` least frequent (word, count) pairs.\n",
    "ranked_counts = word_count.most_common()\n",
    "stopwords = ranked_counts[:border] + ranked_counts[::-1][:border]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Keep only the word of each (word, count) pair. Entries that are already\n",
    "# plain strings are left untouched, so re-running this cell does not truncate\n",
    "# every stopword to its first character.\n",
    "stopwords = [s[0] if isinstance(s, tuple) else s for s in stopwords]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[',', '.', 'the', 'of', 'and', 'baleine', '--(', 'fat', 'oil', 'boiling']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stopwords"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sort the vocabulary: iterating a set of strings can give a different order\n",
    "# on each interpreter run, which would change every word index (and the word\n",
    "# drawn by random.choice later) between runs despite the fixed random seed.\n",
    "vocab = sorted(set(flatten(corpus)) - set(stopwords))\n",
    "vocab.append('<UNK>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "592 583\n"
     ]
    }
   ],
   "source": [
    "print(len(set(flatten(corpus))), len(vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Index 0 is reserved for the unknown-word token; every other vocabulary\n",
    "# entry gets the next free index the first time it is seen.\n",
    "word2index = {'<UNK>': 0}\n",
    "for vo in vocab:\n",
    "    word2index.setdefault(vo, len(word2index))\n",
    "\n",
    "# Reverse mapping: index -> word.\n",
    "index2word = dict((index, word) for word, index in word2index.items())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare train data "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "window data example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/01.skipgram-prepare-data.png\">\n",
    "<center>borrowed image from http://mccormickml.com/2016/04/19/word2vec-tutorial-the-skip-gram-model/</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "WINDOW_SIZE = 3\n",
    "# Pad each sentence with WINDOW_SIZE '<DUMMY>' tokens on both sides, then take\n",
    "# every (2 * WINDOW_SIZE + 1)-gram so each real token is the centre of one window.\n",
    "windows = [\n",
    "    window\n",
    "    for c in corpus\n",
    "    for window in nltk.ngrams(['<DUMMY>'] * WINDOW_SIZE + c + ['<DUMMY>'] * WINDOW_SIZE, 2 * WINDOW_SIZE + 1)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('<DUMMY>', '<DUMMY>', '<DUMMY>', '[', 'moby', 'dick', 'by')"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "windows[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('[', 'moby'), ('[', 'dick'), ('[', 'by'), ('moby', '['), ('moby', 'dick'), ('moby', 'by')]\n"
     ]
    }
   ],
   "source": [
    "# One (center, context) pair for every non-padding neighbour of each window's\n",
    "# centre token (the token at position WINDOW_SIZE).\n",
    "train_data = [\n",
    "    (window[WINDOW_SIZE], context)\n",
    "    for window in windows\n",
    "    for position, context in enumerate(window)\n",
    "    if position != WINDOW_SIZE and context != '<DUMMY>'\n",
    "]\n",
    "\n",
    "print(train_data[:WINDOW_SIZE * 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_p = []\n",
    "y_p = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('[', 'moby')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Rebuild both lists from scratch (instead of appending) so that re-running\n",
    "# this cell does not add a second copy of every pair to X_p / y_p.\n",
    "# Each entry is the word's index as a 1 x 1 tensor.\n",
    "# NOTE: this expects train_data to still hold (center, context) word pairs;\n",
    "# a later cell rebinds train_data to tensor pairs.\n",
    "X_p = [prepare_word(tr[0], word2index).view(1, -1) for tr in train_data]\n",
    "y_p = [prepare_word(tr[1], word2index).view(1, -1) for tr in train_data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_data = list(zip(X_p, y_p))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7606"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/01.skipgram-objective.png\">\n",
    "<center>borrowed image from http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture2.pdf</center>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Skipgram(nn.Module):\n",
    "    \"\"\"Skip-gram model trained with a full softmax over the supplied vocabulary.\n",
    "\n",
    "    Holds two embedding tables: ``embedding_v`` for center words and\n",
    "    ``embedding_u`` for target / outer words.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, projection_dim):\n",
    "        \"\"\"Create both embedding tables.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: number of rows in each embedding table.\n",
    "            projection_dim: size of every embedding vector.\n",
    "        \"\"\"\n",
    "        super(Skipgram, self).__init__()\n",
    "        self.embedding_v = nn.Embedding(vocab_size, projection_dim)\n",
    "        self.embedding_u = nn.Embedding(vocab_size, projection_dim)\n",
    "\n",
    "        self.embedding_v.weight.data.uniform_(-1, 1)  # center vectors: uniform in [-1, 1]\n",
    "        self.embedding_u.weight.data.zero_()  # target/outer vectors: start at zero\n",
    "\n",
    "    def forward(self, center_words, target_words, outer_words):\n",
    "        \"\"\"Return the mean negative log-likelihood of the targets given the centers.\n",
    "\n",
    "        Args:\n",
    "            center_words: index tensor, B x 1.\n",
    "            target_words: index tensor, B x 1.\n",
    "            outer_words: index tensor, B x V; the words that form the softmax\n",
    "                denominator.\n",
    "\n",
    "        Returns:\n",
    "            The negative log-likelihood averaged over the batch.\n",
    "        \"\"\"\n",
    "        center_embeds = self.embedding_v(center_words)  # B x 1 x D\n",
    "        target_embeds = self.embedding_u(target_words)  # B x 1 x D\n",
    "        outer_embeds = self.embedding_u(outer_words)  # B x V x D\n",
    "\n",
    "        scores = target_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2)  # Bx1xD * BxDx1 => Bx1\n",
    "        norm_scores = outer_embeds.bmm(center_embeds.transpose(1, 2)).squeeze(2)  # BxVxD * BxDx1 => BxV\n",
    "\n",
    "        # log-softmax = score - log(sum(exp(norm_scores))). Shifting every score by\n",
    "        # the row maximum before exponentiating leaves the value unchanged but\n",
    "        # keeps exp() from overflowing to inf (which would turn the loss into nan).\n",
    "        max_scores = norm_scores.max(1, keepdim=True)[0]  # B x 1\n",
    "        log_sum_exp = torch.log(torch.sum(torch.exp(norm_scores - max_scores), 1, keepdim=True))  # B x 1\n",
    "        nll = -torch.mean((scores - max_scores) - log_sum_exp)\n",
    "\n",
    "        return nll  # negative log likelihood\n",
    "\n",
    "    def prediction(self, inputs):\n",
    "        \"\"\"Return the center-word embeddings (``embedding_v``) for ``inputs``.\"\"\"\n",
    "        embeds = self.embedding_v(inputs)\n",
    "\n",
    "        return embeds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "EMBEDDING_SIZE = 30\n",
    "BATCH_SIZE = 256\n",
    "EPOCH = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "losses = []\n",
    "model = Skipgram(len(word2index), EMBEDDING_SIZE)\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 0, mean_loss : 6.20\n",
      "Epoch : 10, mean_loss : 4.38\n",
      "Epoch : 20, mean_loss : 3.48\n",
      "Epoch : 30, mean_loss : 3.31\n",
      "Epoch : 40, mean_loss : 3.26\n",
      "Epoch : 50, mean_loss : 3.24\n",
      "Epoch : 60, mean_loss : 3.22\n",
      "Epoch : 70, mean_loss : 3.22\n",
      "Epoch : 80, mean_loss : 3.21\n",
      "Epoch : 90, mean_loss : 3.20\n"
     ]
    }
   ],
   "source": [
    "# The index tensor for the whole vocabulary is identical for every batch, so\n",
    "# build it once here instead of on every iteration.\n",
    "vocab_indices = prepare_sequence(list(vocab), word2index)  # V\n",
    "\n",
    "for epoch in range(EPOCH):\n",
    "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n",
    "\n",
    "        inputs, targets = zip(*batch)\n",
    "\n",
    "        inputs = torch.cat(inputs)  # B x 1\n",
    "        targets = torch.cat(targets)  # B x 1\n",
    "        vocabs = vocab_indices.expand(inputs.size(0), len(vocab))  # B x V\n",
    "        model.zero_grad()\n",
    "\n",
    "        loss = model(inputs, targets, vocabs)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Flatten before indexing so this works whether the loss is a\n",
    "        # one-element tensor or a zero-dimensional one.\n",
    "        losses.append(float(loss.data.view(-1)[0]))\n",
    "\n",
    "    # `losses` is cleared only when a report is printed, so each reported mean\n",
    "    # covers every batch since the previous report.\n",
    "    if epoch % 10 == 0:\n",
    "        print(\"Epoch : %d, mean_loss : %.02f\" % (epoch,np.mean(losses)))\n",
    "        losses = []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def word_similarity(target, vocab):\n",
    "    \"\"\"Return the 10 words of ``vocab`` most similar to ``target``.\n",
    "\n",
    "    Similarity is the cosine similarity between the embeddings returned by the\n",
    "    global ``model.prediction``; words are converted to indices through the\n",
    "    global ``word2index``.\n",
    "\n",
    "    Args:\n",
    "        target: the query word.\n",
    "        vocab: candidate words; entries equal to ``target`` are skipped.\n",
    "\n",
    "    Returns:\n",
    "        Up to 10 ``[word, cosine_similarity]`` lists, most similar first.\n",
    "    \"\"\"\n",
    "    target_V = model.prediction(prepare_word(target, word2index))\n",
    "    similarities = []\n",
    "    for word in vocab:\n",
    "        if word == target:\n",
    "            continue\n",
    "\n",
    "        vector = model.prediction(prepare_word(word, word2index))\n",
    "        cosine_sim = F.cosine_similarity(target_V, vector).data.tolist()[0]\n",
    "        similarities.append([word, cosine_sim])\n",
    "    return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]  # sort by similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'least'"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = random.choice(list(vocab))\n",
    "test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['at', 0.8147411346435547],\n",
       " ['every', 0.7143548130989075],\n",
       " ['case', 0.6975079774856567],\n",
       " ['secure', 0.6121522188186646],\n",
       " ['heart', 0.5974172949790955],\n",
       " ['including', 0.5867112278938293],\n",
       " ['please', 0.5557640194892883],\n",
       " ['has', 0.5536234974861145],\n",
       " ['while', 0.5366998314857483],\n",
       " ['you', 0.509368896484375]]"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word_similarity(test, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
